from defense.base_defense import BaseDefense
from src.helpers.utils import load_model
from src.helpers.utils import logger_setup
from default_config import ModelModes
import os
import torch

class HiFiCDefense(BaseDefense):
    """
    HiFiC Defense class for applying HiFiC to images.

    Each call to ``_defense`` passes the input through the loaded HiFiC
    model ``self.iterations`` times, clamping the output to [0, 1] after
    every pass.
    """

    # Accepted values for the ``weights`` argument; each selects the file
    # ``data/hific_<weights>.pt``.
    _VALID_WEIGHTS = ('low', 'med', 'hi')

    def __init__(self, weights='low', device=None, iterations=1):
        """
        Initialize the HiFiC defense with the specified model weights and device.

        Args:
            weights (str): model weights to use for HiFiC; one of 'low',
                'med' or 'hi' (default: 'low'). Loaded from
                ``data/hific_<weights>.pt``.
            device (torch.device or str): Device to move the model to (default: None).
            iterations (int): Number of compress/decompress passes applied
                by ``_defense`` (default: 1).

        Raises:
            ValueError: If ``weights`` is not one of 'low', 'med' or 'hi'.
            FileNotFoundError: If the weights file for ``weights`` is missing.
        """
        # Validate up front so a bad argument fails before base-class setup,
        # log-directory creation, or model loading happen.
        if weights not in self._VALID_WEIGHTS:
            raise ValueError("weights must be either 'low', 'med' or 'hi'.")
        super().__init__(device, iterations)
        self.weights = weights

        os.makedirs('logs', exist_ok=True)
        self.logger = logger_setup(logpath=os.path.join('logs/', 'logs'), filepath=os.path.abspath('logs/'))

        try:
            self.args, self.model, self.optimizers = load_model(
                f'data/hific_{weights}.pt',
                self.logger,
                self.device,
                model_mode=ModelModes.EVALUATION,
            )
        except FileNotFoundError as err:
            # Re-raise with a clearer message while preserving the original cause.
            raise FileNotFoundError(f"Model weights for '{weights}' not found. Please provide valid weights in the data folder.") from err

        # NOTE(review): presumably BaseDefense moves every attribute listed in
        # device_attributes to self.device — confirm against BaseDefense.
        self.device_attributes.append('model')

    def _defense(self, x):
        """
        Apply the defense to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: ``x`` compressed and decompressed ``self.iterations``
            times, clamped to [0, 1] after each pass.
        """
        for _ in range(self.iterations):
            # The model returns a pair; only the first element is fed into the
            # next pass. The second is presumably auxiliary output (TODO confirm).
            x, _aux = self.model(x)
            x = x.clamp(0, 1)
        return x